{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Tce3stUlHN0L"
      },
      "source": [
        "##### Copyright 2024 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tuOe1ymfHZPu"
      },
      "outputs": [],
      "source": [
        "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dfsDR_omdNea"
      },
      "source": [
        "# Using Gemma  with LangChain\n",
        "This notebook demonstrates how to use Gemma (2B) model with LangChain library.\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_2]Using_with_LangChain.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FaqZItBdeokU"
      },
      "source": [
        "## Setup\n",
        "\n",
        "### Select the Colab runtime\n",
        "To complete this tutorial, you'll need to have a Colab runtime with sufficient resources to run the Gemma model. In this case, you can use a T4 GPU:\n",
        "\n",
        "1. In the upper-right of the Colab window, select **▾ (Additional connection options)**.\n",
        "2. Select **Change runtime type**.\n",
        "3. Under **Hardware accelerator**, select **T4 GPU**.\n",
        "\n",
        "### Gemma setup\n",
        "\n",
        "To complete this tutorial, you'll first need to complete the setup instructions at [Gemma setup](https://ai.google.dev/gemma/docs/setup). The Gemma setup instructions show you how to do the following:\n",
        "\n",
        "* Get access to Gemma on kaggle.com.\n",
        "* Select a Colab runtime with sufficient resources to run\n",
        "  the Gemma 2B model.\n",
        "* Generate and configure a Kaggle username and an API key as Colab secrets.\n",
        "\n",
        "After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CY2kGtsyYpHF"
      },
      "source": [
        "### Configure your credentials\n",
        "\n",
        "Add your your Kaggle credentials to the Colab Secrets manager to securely store it.\n",
        "\n",
        "1. Open your Google Colab notebook and click on the 🔑 Secrets tab in the left panel. <img src=\"https://storage.googleapis.com/generativeai-downloads/images/secrets.jpg\" alt=\"The Secrets tab is found on the left panel.\" width=50%>\n",
        "2. Create new secrets: `KAGGLE_USERNAME` and `KAGGLE_KEY`\n",
        "3. Copy/paste your username into `KAGGLE_USERNAME`\n",
        "3. Copy/paste your key into `KAGGLE_KEY`\n",
        "4. Toggle the buttons on the left to allow notebook access to the secrets.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "A9sUQ4WrP-Yr"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "from google.colab import userdata\n",
        "\n",
        "# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env\n",
        "# vars as appropriate for your system.\n",
        "os.environ[\"KAGGLE_USERNAME\"] = userdata.get(\"KAGGLE_USERNAME\")\n",
        "os.environ[\"KAGGLE_KEY\"] = userdata.get(\"KAGGLE_KEY\")\n",
        "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \"1.0\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iwjo5_Uucxkw"
      },
      "source": [
        "### Install dependencies\n",
        "Run the cell below to install all the required dependencies."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "r_nXPEsF7UWQ"
      },
      "outputs": [],
      "source": [
        "!pip install -q -U tensorflow\n",
        "!pip install -q -U keras keras-nlp"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J3sX2mFH4GWk"
      },
      "source": [
        "### Gemma"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Fz47tAgSKMNH"
      },
      "source": [
        "**About Gemma**\n",
        "\n",
        "Gemma is a family of lightweight, state-of-the-art open models from Google, built from the same research and technology used to create the Gemini models. They are text-to-text, decoder-only large language models, available in English, with open weights, pre-trained variants, and instruction-tuned variants. Gemma models are well-suited for a variety of text generation tasks, including question answering, summarization, and reasoning. Their relatively small size makes it possible to deploy them in environments with limited resources such as a laptop, desktop or your own cloud infrastructure, democratizing access to state of the art AI models and helping foster innovation for everyone.\n",
        "\n",
        "**Prompt formatting**\n",
        "\n",
        "Instruction-tuned (IT) models are trained with a specific formatter that annotates all instruction tuning examples with extra information, both at training and inference time. The formatter has two purposes:\n",
        "\n",
        "* Indicating roles in a conversation, such as the system, user, or assistant roles.\n",
        "* Delineating turns in a conversation, especially in a multi-turn conversation.\n",
        "\n",
        "Below is the control tokens used by Gemma and their use cases. Note that the control tokens are reserved in and specific to our tokenizer.\n",
        "\n",
        "* Token to indicate a user turn: `user`\n",
        "* Token to indicate a model turn: `model`\n",
        "* Token to indicate the beginning of dialogue turn: `<start_of_turn>`\n",
        "* Token to indicate the end of dialogue turn: `<end_of_turn>`\n",
        "\n",
        "Here's the [official documentation](https://ai.google.dev/gemma/docs/formatting) regarding prompting instruction-tuned models."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6NfjcM-Qn1bs"
      },
      "outputs": [],
      "source": [
        "!pip install -q langchain langchain-google-vertexai\n",
        "!pip install -q langchainhub langchain-chroma langchain_community langchain-huggingface\n",
        "!pip install -q sentence-transformers==2.2.2\n",
        "!pip install -q -U huggingface_hub"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "V4cpi3sIoQ2O"
      },
      "outputs": [],
      "source": [
        "# Load Gemma using LangChain library\n",
        "from langchain_google_vertexai import GemmaChatLocalKaggle\n",
        "\n",
        "keras_backend: str = \"jax\"\n",
        "model_name = \"gemma2_instruct_2b_en\"\n",
        "llm = GemmaChatLocalKaggle(\n",
        "    model_name=model_name,\n",
        "    model=model_name,\n",
        "    keras_backend=keras_backend,\n",
        "    max_tokens=1024,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6I7VA43IzwTd"
      },
      "source": [
        "# QA with RAG"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IijuA-i-D5Pi"
      },
      "source": [
        "Retrieval-Augmented Generation (RAG) is a key advancement for Large Language Models (LLMs) for a couple of reasons:\n",
        "\n",
        "- Boosts Factual Accuracy: LLMs are trained on massive amounts of text data, but this data can be outdated or incomplete. RAG tackles this by allowing the LLM to access and incorporate relevant information from external sources during generation. This external fact-checking helps reduce  made-up information, or \"hallucinations,\" in the LLM's outputs, making them more trustworthy.\n",
        "\n",
        "- Enhances Relevance and Depth: RAG provides LLMs with a wider range of knowledge to draw on. When responding to a prompt or question, the LLM can not only use its internal knowledge but also supplement it with specific details retrieved from external data sources. This leads to more comprehensive and informative responses that are precisely tailored to the situation.\n",
        "\n",
        "Overall, RAG elevates the credibility and usefulness of LLMs by ensuring their outputs are grounded in factual information and highly relevant to the context. This is crucial for applications like chatbots, educational tools, and even creative writing where factual grounding and rich detail are important."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a-_U-k6mX6JF"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "WARNING:langchain_community.utils.user_agent:USER_AGENT environment variable not set, consider setting it to identify your requests.\n"
          ]
        }
      ],
      "source": [
        "import bs4\n",
        "from langchain_huggingface.embeddings import HuggingFaceEmbeddings\n",
        "from langchain_core.output_parsers import BaseTransformOutputParser\n",
        "from langchain import hub\n",
        "from langchain_chroma import Chroma\n",
        "from langchain_community.document_loaders import WebBaseLoader\n",
        "from langchain_core.output_parsers import StrOutputParser\n",
        "from langchain_core.runnables import RunnablePassthrough\n",
        "from langchain_text_splitters import RecursiveCharacterTextSplitter"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ETtaWSGMh5Q9"
      },
      "outputs": [],
      "source": [
        "# Helpers\n",
        "class GemmaOutputParser(BaseTransformOutputParser[str]):\n",
        "    \"\"\"OutputParser that parses LLM response and extract\n",
        "    the generated part.\"\"\"\n",
        "\n",
        "    @classmethod\n",
        "    def is_lc_serializable(cls) -> bool:\n",
        "        \"\"\"Return whether this class is serializable.\"\"\"\n",
        "        return True\n",
        "\n",
        "    @property\n",
        "    def _type(self) -> str:\n",
        "        \"\"\"Return the output parser type for serialization.\"\"\"\n",
        "        return \"gemma_2_parser\"\n",
        "\n",
        "    def parse(self, text: str) -> str:\n",
        "        \"\"\"Return the input text with no changes.\"\"\"\n",
        "        model_start_token = \"<start_of_turn>model\\n\"\n",
        "        idx = text.rfind(model_start_token)\n",
        "        return text[idx + len(model_start_token) :] if idx > -1 else \"\"\n",
        "\n",
        "\n",
        "def format_docs(docs):\n",
        "    return \"\\n\\n\".join(doc.page_content for doc in docs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RsdZhKsJ-1nD"
      },
      "source": [
        "## Creating vector store"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "310K9XY0-oKU"
      },
      "source": [
        "You will use [this blog post](https://developers.google.com/machine-learning/resources/intro-llms) as a data source for our application. In this section you will fetch the data, chunk it and load it into our vector store."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HrxZCfK2UeS9"
      },
      "outputs": [],
      "source": [
        "# Load, chunk and index the contents of the blog.\n",
        "loader = WebBaseLoader(\n",
        "    web_paths=(\"https://developers.google.com/machine-learning/resources/intro-llms\",),\n",
        "    bs_kwargs=dict(\n",
        "        parse_only=bs4.SoupStrainer(name=(\"h3\", \"p\"))\n",
        "    ),\n",
        ")\n",
        "docs = loader.load()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ben7d6YPY6yj"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_deprecation.py:131: FutureWarning: 'cached_download' (from 'huggingface_hub.file_download') is deprecated and will be removed from version '0.26'. Use `hf_hub_download` instead.\n",
            "  warnings.warn(warning_message, FutureWarning)\n"
          ]
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "4f8b08607e5442388a3d8a87df2a3af9",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              ".gitattributes:   0%|          | 0.00/1.23k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "41979bc368114a68accdf4dd186bb9df",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "4ef15978bc544270818dd4c9e7dbd81a",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "README.md:   0%|          | 0.00/10.6k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "91aec707b72641cc9bcc34ea460d97fe",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "88236c6ea82247e9abf9e3584f922da0",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "87eaa7b2ae774b029ad52a889c0a1c13",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "data_config.json:   0%|          | 0.00/39.3k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "9fdf8740faae4b138fbce9b2259d2bc4",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "bb0b3c8911fa483d912f0a1417167a66",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "6acc094eec7946118400838d72eb9b9e",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "e00ceadf7d8b4cb687690d529f1fe17e",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "9b797bf64e6446deab08fe9c3615bb2e",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "50724b39f84c46099cd9f15a4ff0a2f1",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "tokenizer_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "48d0c234117e4548ab6eb3fc374ecf97",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "train_script.py:   0%|          | 0.00/13.1k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "dd87ca0614274101b9f610f4f348b853",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "41e738328dd04ec999e32d843ee7aabc",
              "version_major": 2,
              "version_minor": 0
            },
            "text/plain": [
              "modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]"
            ]
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "# Create a vector store with all the docs\n",
        "text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)\n",
        "splits = text_splitter.split_documents(docs)\n",
        "vectorstore = Chroma.from_documents(documents=splits, embedding=HuggingFaceEmbeddings())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nId7fNy7Y6wq"
      },
      "outputs": [],
      "source": [
        "# Retrieve and generate using the relevant snippets of the blog.\n",
        "retriever = vectorstore.as_retriever()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2roNuR35-6CU"
      },
      "source": [
        "## Creating a RAG Chain"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jELJkevgAiXJ"
      },
      "source": [
        "Here's a resource to learn more about the LCEL paradigm: [the official documentation](https://python.langchain.com/v0.1/docs/expression_language/why/)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "j3JCAT4H_Cv7"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Prompt:\n",
            "\n",
            "You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.\n",
            "Question: {question} \n",
            "Context: {context} \n",
            "Answer:\n"
          ]
        }
      ],
      "source": [
        "# Let's load a predefined prompt for this task\n",
        "prompt = hub.pull(\"rlm/rag-prompt\")\n",
        "print(f\"Prompt:\\n\\n{prompt.messages[0].prompt.template}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9xtJOe2LY6uz"
      },
      "outputs": [],
      "source": [
        "# Create an actual chain\n",
        "\n",
        "rag_chain = (\n",
        "    # First you need retrieve documents that are relevant to the\n",
        "    # given query\n",
        "    {\"context\": retriever | format_docs, \"question\": RunnablePassthrough()}\n",
        "    # The output is passed the prompt and fills fields like `{question}`\n",
        "    # and `{context}`\n",
        "    | prompt\n",
        "    # The whole prompt will all the information is passed the LLM\n",
        "    | llm\n",
        "    # The answer of the LLM is parsed by the class defined above\n",
        "    | GemmaOutputParser()\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yZT5H4yR-Z8P"
      },
      "source": [
        "## Let's try it out!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Xs8A6dTmiNtI"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.google.colaboratory.intrinsic+json": {
              "type": "string"
            },
            "text/plain": [
              "'Transformers are a state-of-the-art architecture for language modeling, designed around the concept of attention. They consist of an encoder and a decoder that convert input text into useful text, focusing on the most important parts of the input. Transformers are highly effective at various language tasks, including translation and text generation. \\n<end_of_turn>'"
            ]
          },
          "execution_count": 12,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "rag_chain.invoke(\"What are transformers?\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "brYMrZqwzyi2"
      },
      "source": [
        "# Extracting structured output (JSON)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TS6L7cwGfpwt"
      },
      "source": [
        "Traditionally, information extraction involved complex systems with hand-written rules and custom models, which were costly to maintain.\n",
        "\n",
        "Large Language Models (LLMs) offer a new approach. They can be instructed and given examples to perform specific extraction tasks, making them quicker to adapt and use.\n",
        "\n",
        "The following section will show you to use Gemma to extract information from a query using LangChain."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7VeaQyz1mH5j"
      },
      "outputs": [],
      "source": [
        "from langchain_core.output_parsers import JsonOutputParser\n",
        "from langchain_core.prompts import PromptTemplate\n",
        "from langchain_core.pydantic_v1 import BaseModel, Field"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ka60OmWYNS3q"
      },
      "source": [
        "## Implementing required steps"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "C48xtHCNAlgJ"
      },
      "outputs": [],
      "source": [
        "# Define the schema of the data you want to extract\n",
        "class Person(BaseModel):\n",
        "    name: str = Field(description=\"person's name\")\n",
        "    age: str = Field(description=\"person's age\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QXPOHZF4AnVS"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Schema passed to the LLM:\n",
            "{\n",
            "  name: <person's name>\n",
            "  age: <person's age>\n",
            "}\n"
          ]
        }
      ],
      "source": [
        "# Helpers\n",
        "def get_data_schema(pydantic):\n",
        "    \"\"\"A helper function that generates JSON schema that\n",
        "    the model can use to fill it with information\"\"\"\n",
        "    schema = {k: v for k, v in pydantic.schema().items()}\n",
        "    fields = [(k, v[\"description\"]) for k, v in schema[\"properties\"].items()]\n",
        "    json = \"\\n\".join(f\"  {name}: <{desc}>\" for (name, desc) in fields)\n",
        "    schema = \"{\\n\" + json + \"\\n}\"\n",
        "    return schema\n",
        "\n",
        "\n",
        "print(f\"Schema passed to the LLM:\\n{get_data_schema(Person)}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1JskyClHkTE1"
      },
      "outputs": [],
      "source": [
        "# Define a prompt for the task explaining what needs to be done\n",
        "prompt_template = \"\"\"Extract data from the query to JSON format.\n",
        "Required schema:\\n{format_instructions}. Do not add new keys.\n",
        "\\n{query}\\n\"\"\"\n",
        "\n",
        "format_instructions = get_data_schema(Person)\n",
        "\n",
        "prompt = PromptTemplate(\n",
        "    template=prompt_template,\n",
        "    input_variables=[\"query\"],\n",
        "    partial_variables={\"format_instructions\": format_instructions},\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tKr1SWGMAv8Q"
      },
      "outputs": [],
      "source": [
        "# Let's create a chain that will tie all the parts together\n",
        "chain = prompt | llm | GemmaOutputParser() | JsonOutputParser()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8pBnbH8rBFl5"
      },
      "source": [
        "## Let's try it out!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sBTuzTEGvMWo"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'name': 'Kate', 'age': 26}"
            ]
          },
          "execution_count": 18,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "query = \"Kate is 26 years old and lives in Warsaw.\"\n",
        "chain.invoke({\"query\": query})"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "u3NvUGfoDBCQ"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "{'name': 'Ben', 'age': 33}"
            ]
          },
          "execution_count": 19,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "query = \"\"\"In the midst of London's bustling streets, 33-year-old Ben\n",
        "weaved between double-decker buses. Fueled by a quick bite between\n",
        "museums, he was on a mission to absorb every corner of the city.\n",
        "This trip was a dream come true, and Ben couldn't wait to unearth\n",
        "the next hidden gem waiting to be discovered.\"\"\"\n",
        "chain.invoke({\"query\": query})"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "30rve6k08gKc"
      },
      "source": [
        "# Using tools"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "m7bqIMBKBfGl"
      },
      "source": [
        "There are two main reasons why tools are beneficial for LLMs (Large Language Models):\n",
        "\n",
        "- Enhanced Capabilities: LLMs are incredibly knowledgeable, but they can't access and process information in real-time the way a human can.  Tools like search engines and databases provide LLMs with a way to  find and integrate up-to-date information,  effectively extending their knowledge and abilities.  For instance, an LLM  could be writing a report, but it might need to access  specific statistics or research papers to complete the task.  By using  search tools, the LLM can  find this information and incorporate it into the report.\n",
        "\n",
        "- Real-World Interaction: LLMs themselves can't directly interact with the physical world.  However, tools like  programming interfaces (APIs)  allow LLMs to connect with and  control  various applications and devices.  This opens doors to a much wider range of applications,  like controlling smart home devices or generating code.\n",
        "\n",
        "In essence, tools bridge the gap between the vast knowledge stored within an LLM and the ability to use that knowledge in a practical way."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "k50MyUwj76zt"
      },
      "outputs": [],
      "source": [
        "from langchain_core.tools import tool\n",
        "from langchain.agents import AgentExecutor, create_tool_calling_agent\n",
        "from langchain.tools.render import render_text_description\n",
        "from langchain_core.prompts import ChatPromptTemplate\n",
        "from operator import itemgetter"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Hl48lxOzNgDK"
      },
      "source": [
        "## Implementing required steps"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "60NZtMPl8nkH"
      },
      "outputs": [],
      "source": [
        "# Define set of tools that can be used by the LLM\n",
        "\n",
        "\n",
        "@tool\n",
        "def multiply(first_int: int, second_int: int) -> int:\n",
        "    \"\"\"Multiply two integers together.\n",
        "       (operators: mulitplied, *, times, etc.)\"\"\"\n",
        "    print(\"(tool called: multiply)\")\n",
        "    return first_int * second_int\n",
        "\n",
        "\n",
        "@tool\n",
        "def add(first_int: int, second_int: int) -> int:\n",
        "    \"\"\"Add two integers.\n",
        "       (operators: plus, added, +)\"\"\"\n",
        "    print(\"(tool called: add)\")\n",
        "    return first_int + second_int\n",
        "\n",
        "\n",
        "@tool\n",
        "def exponentiate(base: int, exponent: int) -> int:\n",
        "    \"\"\" Returns the value of `base` to the power of `exponent`\n",
        "       (operators: power to, **, exp)\"\"\"\n",
        "    print(\"(tool called: exponentiate)\")\n",
        "    return base**exponent\n",
        "\n",
        "\n",
        "tools = [add, exponentiate, multiply]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bOVeHscq9D1G"
      },
      "outputs": [],
      "source": [
        "# Helper\n",
        "def tool_chain(model_output):\n",
        "    \"\"\"A function that maps name of a tool to an actual\n",
        "    implementation and passes all the args\"\"\"\n",
        "    tool_map = {tool.name: tool for tool in tools}\n",
        "    chosen_tool = tool_map[model_output[\"name\"]]\n",
        "    return itemgetter(\"arguments\") | chosen_tool"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fGBEqPgz9Gmx"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Available tools:\n",
            "add(first_int: int, second_int: int) -> int - Add two integers.\n",
            "       (operators: plus, added, +)\n",
            "exponentiate(base: int, exponent: int) -> int - Returns the value of `base` to the power of `exponent`\n",
            "      (operators: power to, **, exp)\n",
            "multiply(first_int: int, second_int: int) -> int - Multiply two integers together.\n",
            "       (operators: mulitplied, *, times, etc.)\n"
          ]
        }
      ],
      "source": [
        "# LLM are text-based models so in order to inform the model\n",
        "# what tools can be used (and how) you need to describe them\n",
        "# using natural language\n",
        "rendered_tools = render_text_description(tools)\n",
        "print(f\"Available tools:\\n{rendered_tools}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "epLUP9ZZBzi0"
      },
      "outputs": [],
      "source": [
        "# Let's define prompot and inject the information about tools\n",
        "system_prompt = f\"\"\"You are an assistant that has access to the following set of tools. Here are the names and descriptions for each tool:\n",
        "\n",
        "{rendered_tools}\n",
        "\n",
        "Given the user input, return the name and input of the tool to use. Return your response as a JSON blob with 'name' and 'arguments' keys.\n",
        "Arguments should also be a JSON where the key is argument's name and the value is the value of that argument.\"\"\"\n",
        "\n",
        "prompt = ChatPromptTemplate.from_messages([(\"ai\", system_prompt), (\"user\", \"{input}\")])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zXeDhACiByE1"
      },
      "outputs": [],
      "source": [
        "# Define a chain that tie all parts together\n",
        "chain = prompt | llm | GemmaOutputParser() | JsonOutputParser() | tool_chain"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jcyM1IK6CWvx"
      },
      "source": [
        "## Let's try it out!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TJZc-As4-K-M"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "(tool called: add)\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "1326"
            ]
          },
          "execution_count": 26,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "chain.invoke({\"input\": \"what's 3 plus 1323?\"})"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Oa3mtl19-MNL"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "(tool called: exponentiate)\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "64"
            ]
          },
          "execution_count": 27,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "chain.invoke({\"input\": \"what's 4 to the power or 3?\"})"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Zv_MqikP-MKp"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "(tool called: multiply)\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              "25"
            ]
          },
          "execution_count": 28,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "chain.invoke({\"input\": \"what's 5 * 5?\"})"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "[Gemma_2]Using_with_LangChain.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
